/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    FTtree.java
 *    Copyright (C) 2007 University of Porto, Porto, Portugal
 *
 */

package weka.classifiers.trees.ft;

import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.BinC45Split;
import weka.classifiers.trees.j48.C45Split;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.Distribution;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.j48.Stats;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

import java.util.Vector;

/**
 * Abstract class for Functional tree structure.
 * 
 * @author Jo\~{a}o Gama
 * @author Carlos Ferreira
 * 
 * @version $Revision: 1.4 $
 */
public abstract class FTtree extends LogisticBase {

	/** for serialization */
	static final long serialVersionUID = 1862737145870398755L;

	/** Total number of training instances. */
	protected double m_totalInstanceWeight;

	/** Node id (set by {@link #assignIDs(int)}) */
	protected int m_id;

	/**
	 * ID of logistic model at leaf (set by
	 * {@link #assignLeafModelNumbers(int)}; 0 for inner nodes)
	 */
	protected int m_leafModelNum;

	/** minimum number of instances at which a node is considered for splitting */
	protected int m_minNumInstances;

	/** ModelSelection object (for splitting) */
	protected ModelSelection m_modelSelection;

	/** Filter to convert nominal attributes to binary */
	protected NominalToBinary m_nominalToBinary;

	/**
	 * Simple regression functions fit by LogitBoost at higher levels in the
	 * tree. Indexed as [class][regression].
	 */
	protected SimpleLinearRegression[][] m_higherRegressions;

	/**
	 * Number of simple regression functions fit by LogitBoost at higher levels
	 * in the tree
	 */
	protected int m_numHigherRegressions = 0;

	/** Number of instances at the node */
	protected int m_numInstances;

	/** The ClassifierSplitModel (for splitting) */
	protected ClassifierSplitModel m_localModel;

	/** Auxiliary copy ClassifierSplitModel (for splitting) */
	protected ClassifierSplitModel m_auxLocalModel;

	/** Array of children of the node */
	protected FTtree[] m_sons;

	/** Stores leaf class value */
	protected int m_leafclass;

	/** True if node is leaf */
	protected boolean m_isLeaf;

	/** True if node has or splits on constructor */
	protected boolean m_hasConstr = true;

	/** Constructor error */
	protected double m_constError = 0;

	/** Confidence level */
	protected float m_CF = 0.10f;

	/**
	 * Method for building a Functional Tree (only called for the root node).
	 * Grows an initial Functional Tree.
	 * 
	 * @param data
	 *            the data to train with
	 * @throws Exception
	 *             if something goes wrong
	 */
	public abstract void buildClassifier(Instances data) throws Exception;

	/**
	 * Abstract method for building the tree structure. Builds a logistic model,
	 * splits the node and recursively builds tree for child nodes.
	 * 
	 * @param data
	 *            the training data passed on to this node
	 * @param higherRegressions
	 *            An array of regression functions produced by LogitBoost at
	 *            higher levels in the tree. They represent a logistic
	 *            regression model that is refined locally at this node.
	 * @param totalInstanceWeight
	 *            the total number of training examples
	 * @param higherNumParameters
	 *            effective number of parameters in the logistic regression
	 *            model built in parent nodes
	 * @throws Exception
	 *             if something goes wrong
	 */
	public abstract void buildTree(Instances data,
			SimpleLinearRegression[][] higherRegressions,
			double totalInstanceWeight, double higherNumParameters)
			throws Exception;

	/**
	 * Abstract Method that prunes a tree using C4.5 pruning procedure.
	 * 
	 * @return a value computed by the implementing subclass (not defined at
	 *         this level)
	 * @exception Exception
	 *                if something goes wrong
	 */
	public abstract double prune() throws Exception;

	/**
	 * Inserts one new attribute per class value into the given dataset. The
	 * attribute for class value <code>i</code> is named <code>"N" + i</code>
	 * and is inserted at position <code>i</code>, so the new attributes occupy
	 * the leading positions. The dataset is modified in place.
	 * 
	 * @param data
	 *            the dataset to extend (modified in place)
	 * @return the same dataset object that was passed in
	 * @exception Exception
	 *                if something goes wrong
	 */
	protected Instances insertNewAttr(Instances data) throws Exception {

		for (int i = 0; i < data.classAttribute().numValues(); i++) {
			data.insertAttributeAt(new Attribute("N" + i), i);
		}
		return data;
	}

	/**
	 * Removes the extended attributes from the given dataset by deleting the
	 * attribute at position 0 once per class value (the counterpart of
	 * {@link #insertNewAttr(Instances)}). The dataset is modified in place.
	 * 
	 * @param data
	 *            the dataset to shrink (modified in place)
	 * @return the same dataset object that was passed in
	 * @exception Exception
	 *                if something goes wrong
	 */
	protected Instances removeExtAttributes(Instances data) throws Exception {

		for (int i = 0; i < data.classAttribute().numValues(); i++) {
			data.deleteAttributeAt(0);
		}
		return data;
	}

	/**
	 * Computes estimated errors for tree. For a leaf this is the estimate for
	 * the distribution of the local model; for an inner node it is the sum of
	 * the estimates of all children.
	 * 
	 * @return the estimated errors
	 */
	protected double getEstimatedErrors() {

		if (m_isLeaf) {
			return getEstimatedErrorsForDistribution(m_localModel
					.distribution());
		}

		double errors = 0;
		for (int i = 0; i < m_sons.length; i++) {
			errors = errors + m_sons[i].getEstimatedErrors();
		}
		return errors;
	}

	/**
	 * Computes estimated errors for one branch, using the given data instead of
	 * the stored distributions.
	 * <p>
	 * Note: for inner nodes this resets the distribution of the local split
	 * model to the one induced by <code>data</code>; the previous distribution
	 * is not restored.
	 * 
	 * @param data
	 *            the instances reaching this node
	 * @return the estimated errors summed over the leaves of this branch
	 * @exception Exception
	 *                if something goes wrong
	 */
	protected double getEstimatedErrorsForBranch(Instances data)
			throws Exception {

		if (m_isLeaf) {
			return getEstimatedErrorsForDistribution(new Distribution(data));
		}

		m_localModel.resetDistribution(data);
		Instances[] localInstances = (Instances[]) m_localModel.split(data);

		double errors = 0;
		for (int i = 0; i < m_sons.length; i++) {
			errors = errors
					+ m_sons[i].getEstimatedErrorsForBranch(localInstances[i]);
		}
		return errors;
	}

	/**
	 * Computes estimated errors for leaf, based on the number of incorrectly
	 * classified instances in the given distribution.
	 * 
	 * @param theDistribution
	 *            the class distribution at the leaf
	 * @return 0 if the distribution is empty, otherwise
	 *         <code>(Stats.addErrs(total, incorrect, m_CF) + incorrect) / total</code>
	 */
	protected double getEstimatedErrorsForDistribution(
			Distribution theDistribution) {
		if (Utils.eq(theDistribution.total(), 0)) {
			return 0;
		}
		double numInc = theDistribution.numIncorrect();
		double numTotal = theDistribution.total();
		return estimateErrorRate(numTotal, numInc);
	}

	/**
	 * Computes estimated errors for Constructor Model, based on the stored
	 * constructor error (<code>m_constError</code>) rather than on the number
	 * of incorrect instances in the distribution.
	 * 
	 * @param theDistribution
	 *            the class distribution; only its total weight is used
	 * @return 0 if the distribution is empty, otherwise
	 *         <code>(Stats.addErrs(total, m_constError, m_CF) + m_constError) / total</code>
	 */
	protected double getEtimateConstModel(Distribution theDistribution) {
		if (Utils.eq(theDistribution.total(), 0)) {
			return 0;
		}
		return estimateErrorRate(theDistribution.total(), m_constError);
	}

	/**
	 * Shared formula of the two error estimators above: adds the correction
	 * returned by <code>Stats.addErrs</code> (for confidence level
	 * <code>m_CF</code>) to the observed errors and divides by the total
	 * weight.
	 * 
	 * @param numTotal
	 *            total weight; callers make sure it is not zero
	 * @param numErrors
	 *            observed number of errors
	 * @return the corrected errors divided by <code>numTotal</code>
	 */
	private double estimateErrorRate(double numTotal, double numErrors) {
		return ((Stats.addErrs(numTotal, numErrors, m_CF)) + numErrors)
				/ numTotal;
	}

	/**
	 * Method to count the number of inner nodes in the tree
	 * 
	 * @return the number of inner nodes
	 */
	public int getNumInnerNodes() {
		if (m_isLeaf) {
			return 0;
		}
		int numNodes = 1;
		for (int i = 0; i < m_sons.length; i++) {
			numNodes += m_sons[i].getNumInnerNodes();
		}
		return numNodes;
	}

	/**
	 * Returns the number of leaves in the tree, where all direct leaf children
	 * of a node that have no models of their own (see {@link #hasModels()})
	 * are counted as a single leaf.
	 * 
	 * @return the number of leaves
	 */
	public int getNumLeaves() {
		if (m_isLeaf) {
			return 1;
		}

		int numLeaves = 0;
		int numEmptyLeaves = 0;
		for (int i = 0; i < m_sons.length; i++) {
			numLeaves += m_sons[i].getNumLeaves();
			if (m_sons[i].m_isLeaf && !m_sons[i].hasModels()) {
				numEmptyLeaves++;
			}
		}
		// several model-less leaf children count as one leaf only
		if (numEmptyLeaves > 1) {
			numLeaves -= (numEmptyLeaves - 1);
		}
		return numLeaves;
	}

	/**
	 * Merges two arrays of regression functions into one. Both arrays are
	 * indexed as [class][regression]; for every class the entries of
	 * <code>a1</code> come first, followed by the entries of <code>a2</code>.
	 * The number of regressions per class is taken from the first row of each
	 * array.
	 * 
	 * @param a1
	 *            one array
	 * @param a2
	 *            the other array
	 * 
	 * @return an array that contains all entries from both input arrays
	 */
	protected SimpleLinearRegression[][] mergeArrays(
			SimpleLinearRegression[][] a1, SimpleLinearRegression[][] a2) {
		int numModels1 = a1[0].length;
		int numModels2 = a2[0].length;

		SimpleLinearRegression[][] result = new SimpleLinearRegression[m_numClasses][numModels1
				+ numModels2];

		for (int i = 0; i < m_numClasses; i++) {
			System.arraycopy(a1[i], 0, result[i], 0, numModels1);
			System.arraycopy(a2[i], 0, result[i], numModels1, numModels2);
		}
		return result;
	}

	/**
	 * Return a list of all inner nodes in the tree
	 * 
	 * @return the list of nodes (elements are of type FTtree)
	 */
	public Vector getNodes() {
		Vector nodeList = new Vector();
		getNodes(nodeList);
		return nodeList;
	}

	/**
	 * Fills a list with all inner nodes in the tree. A node is added before
	 * its children are visited; leaves are not added.
	 * 
	 * @param nodeList
	 *            the list to be filled
	 */
	public void getNodes(Vector nodeList) {
		if (m_isLeaf) {
			return;
		}
		nodeList.add(this);
		for (int i = 0; i < m_sons.length; i++) {
			m_sons[i].getNodes(nodeList);
		}
	}

	/**
	 * Returns a numeric version of a set of instances. All nominal attributes
	 * are replaced by binary ones, and the class variable is replaced by a
	 * pseudo-class variable that is used by LogitBoost.
	 * <p>
	 * As a side effect a new NominalToBinary filter is created, initialized
	 * on a copy of the given data and stored in <code>m_nominalToBinary</code>.
	 * 
	 * @param train
	 *            the training data (not modified; a copy is filtered)
	 * @return the numeric version of the data
	 * @throws Exception
	 *             if the filter or the conversion in the superclass fails
	 */
	protected Instances getNumericData(Instances train) throws Exception {

		Instances filteredData = new Instances(train);
		m_nominalToBinary = new NominalToBinary();
		m_nominalToBinary.setInputFormat(filteredData);
		filteredData = Filter.useFilter(filteredData, m_nominalToBinary);

		return super.getNumericData(filteredData);
	}

	/**
	 * Computes the F-values of LogitBoost for an instance from the current
	 * logistic model at the node. Note that this also takes into account the
	 * (partial) logistic model fit at higher levels in the tree.
	 * 
	 * @param instance
	 *            the instance
	 * @return the array of F-values
	 * @throws Exception
	 *             if a prediction cannot be computed
	 */
	protected double[] getFs(Instance instance) throws Exception {

		double[] pred = new double[m_numClasses];

		// Need to take into account partial model fit at higher levels in the
		// tree (m_higherRegressions)
		// and the part of the model fit at this node (m_regressions).

		// Fs from m_regressions (use method of LogisticBase)
		double[] instanceFs = super.getFs(instance);

		// Fs from m_higherRegressions
		for (int i = 0; i < m_numHigherRegressions; i++) {
			double predSum = 0;
			for (int j = 0; j < m_numClasses; j++) {
				pred[j] = m_higherRegressions[j][i].classifyInstance(instance);
				predSum += pred[j];
			}
			// center the predictions of this iteration around their mean
			predSum /= m_numClasses;
			for (int j = 0; j < m_numClasses; j++) {
				instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1)
						/ m_numClasses;
			}
		}
		return instanceFs;
	}

	/**
	 * Returns the index of the largest value in the given array, as computed
	 * by <code>Utils.maxIndex</code>.
	 * 
	 * @param probsConst
	 *            the array of values to inspect
	 * @return the index of the maximum element
	 */
	public int getConstError(double[] probsConst) {
		return Utils.maxIndex(probsConst);
	}

	/**
	 * Returns true if the logistic regression model at this node has changed
	 * compared to the one at the parent node, i.e. if at least one regression
	 * function was fit at this node.
	 * 
	 * @return whether it has changed
	 */
	public boolean hasModels() {
		return (m_numRegressions > 0);
	}

	/**
	 * Returns the class probabilities for an instance according to the logistic
	 * model at the node.
	 * 
	 * @param instance
	 *            the instance (not modified; a copy is converted)
	 * @return the array of probabilities
	 * @throws Exception
	 *             if the instance cannot be converted or classified
	 */
	public double[] modelDistributionForInstance(Instance instance)
			throws Exception {

		// make copy and convert nominal attributes
		instance = (Instance) instance.copy();
		m_nominalToBinary.input(instance);
		instance = m_nominalToBinary.output();

		// set numeric pseudo-class
		instance.setDataset(m_numericDataHeader);

		return probs(getFs(instance));
	}

	/**
	 * Returns the class probabilities for an instance given by the Functional
	 * tree.
	 * 
	 * @param instance
	 *            the instance
	 * @return the array of probabilities
	 * @throws Exception
	 *             if the distribution cannot be computed
	 */
	public abstract double[] distributionForInstance(Instance instance)
			throws Exception;

	/**
	 * Returns a description of the Functional tree (tree structure and logistic
	 * models). Renumbers the leaf models as a side effect. Node IDs are
	 * printed as currently stored; this method does not call
	 * {@link #assignIDs(int)}.
	 * 
	 * @return describing string, or a fixed fallback text if printing fails
	 */
	public String toString() {
		// assign numbers to logistic regression functions at leaves
		assignLeafModelNumbers(0);
		try {
			StringBuffer text = new StringBuffer();

			if (m_isLeaf) {
				text.append(": ");
				text.append(leafLabel());
			} else {
				dumpTree(0, text);
			}
			text.append("\n\nNumber of Leaves  : \t" + numLeaves() + "\n");
			text.append("\nSize of the Tree : \t" + numNodes() + "\n");

			// This prints logistic models after the tree, comment out if only
			// tree should be printed
			text.append(modelsToString());
			return text.toString();
		} catch (Exception e) {
			return "Can't print logistic model tree";
		}
	}

	/**
	 * Returns the label used for this node when it is printed as a leaf:
	 * <code>FT_&lt;model number&gt;:&lt;model parameters&gt;</code> if the
	 * leaf has a constructor, <code>Class=&lt;leaf class&gt;</code> otherwise.
	 * 
	 * @return the leaf label
	 */
	private String leafLabel() {
		if (m_hasConstr) {
			return "FT_" + m_leafModelNum + ":" + getModelParameters();
		}
		return "Class" + "=" + m_leafclass;
	}

	/**
	 * Returns the number of leaves (normal count).
	 * 
	 * @return the number of leaves
	 */
	public int numLeaves() {
		if (m_isLeaf) {
			return 1;
		}
		int numLeaves = 0;
		for (int i = 0; i < m_sons.length; i++) {
			numLeaves += m_sons[i].numLeaves();
		}
		return numLeaves;
	}

	/**
	 * Returns the number of nodes (inner nodes and leaves).
	 * 
	 * @return the number of nodes
	 */
	public int numNodes() {
		if (m_isLeaf) {
			return 1;
		}
		int numNodes = 1;
		for (int i = 0; i < m_sons.length; i++) {
			numNodes += m_sons[i].numNodes();
		}
		return numNodes;
	}

	/**
	 * Returns a string describing the number of LogitBoost iterations performed
	 * at this node, the total number of LogitBoost iterations performed
	 * (including iterations at higher levels in the tree), and the number of
	 * training examples at this node. The format is
	 * <code>&lt;local&gt;/&lt;local + higher&gt; (&lt;instances&gt;)</code>.
	 * 
	 * @return the describing string
	 */
	public String getModelParameters() {

		int numModels = m_numRegressions + m_numHigherRegressions;
		return m_numRegressions + "/" + numModels + " (" + m_numInstances
				+ ")";
	}

	/**
	 * Help method for printing tree structure. Appends one line per child of
	 * this node and recurses into children that are not leaves.
	 * 
	 * @param depth
	 *            the depth of this node, used for indentation
	 * @param text
	 *            the buffer the description is appended to
	 * @throws Exception
	 *             if something goes wrong
	 */
	protected void dumpTree(int depth, StringBuffer text) throws Exception {

		for (int i = 0; i < m_sons.length; i++) {
			FTtree son = m_sons[i];

			text.append("\n");
			for (int j = 0; j < depth; j++) {
				text.append("|   ");
			}
			// nodes with a constructor are tagged with their node id
			if (m_hasConstr) {
				text.append(m_localModel.leftSide(m_train) + "#" + m_id);
			} else {
				text.append(m_localModel.leftSide(m_train));
			}
			text.append(m_localModel.rightSide(i, m_train));

			if (son.m_isLeaf) {
				text.append(": ");
				text.append(son.leafLabel());
			} else {
				son.dumpTree(depth + 1, text);
			}
		}
	}

	/**
	 * Assigns unique IDs to all nodes in the tree. This node gets
	 * <code>lastID + 1</code>; the children are numbered consecutively after
	 * it, each subtree before the next sibling.
	 * 
	 * @param lastID
	 *            the last ID that has already been used
	 * @return the last ID used in this subtree
	 */
	public int assignIDs(int lastID) {

		int currLastID = lastID + 1;

		m_id = currLastID;
		if (m_sons != null) {
			for (int i = 0; i < m_sons.length; i++) {
				currLastID = m_sons[i].assignIDs(currLastID);
			}
		}
		return currLastID;
	}

	/**
	 * Assigns numbers to the logistic regression models at the leaves of the
	 * tree. Leaves are numbered consecutively starting at
	 * <code>leafCounter + 1</code>; inner nodes get the number 0.
	 * 
	 * @param leafCounter
	 *            the last leaf number that has already been used
	 * @return the last leaf number used in this subtree
	 */
	public int assignLeafModelNumbers(int leafCounter) {
		if (m_isLeaf) {
			leafCounter++;
			m_leafModelNum = leafCounter;
			return leafCounter;
		}

		m_leafModelNum = 0;
		// m_sons is null until the node has been split (same guard as in
		// assignIDs); without it toString() would fail with a
		// NullPointerException outside of its try block
		if (m_sons != null) {
			for (int i = 0; i < m_sons.length; i++) {
				leafCounter = m_sons[i].assignLeafModelNumbers(leafCounter);
			}
		}
		return leafCounter;
	}

	/**
	 * Returns an array containing the coefficients of the logistic regression
	 * function at this node.
	 * 
	 * @return the array of coefficients, first dimension is the class, second
	 *         the attribute.
	 */
	protected double[][] getCoefficients() {

		// Need to take into account partial model fit at higher levels in the
		// tree (m_higherRegressions)
		// and the part of the model fit at this node (m_regressions).

		// get coefficients from m_regressions: use method of LogisticBase
		double[][] coefficients = super.getCoefficients();
		// get coefficients from m_higherRegressions:
		double constFactor = (double) (m_numClasses - 1)
				/ (double) m_numClasses; // (J - 1)/J
		for (int j = 0; j < m_numClasses; j++) {
			for (int i = 0; i < m_numHigherRegressions; i++) {
				double slope = m_higherRegressions[j][i].getSlope();
				double intercept = m_higherRegressions[j][i].getIntercept();
				int attribute = m_higherRegressions[j][i].getAttributeIndex();
				// index 0 holds the intercept, attribute k is at index k + 1
				coefficients[j][0] += constFactor * intercept;
				coefficients[j][attribute + 1] += constFactor * slope;
			}
		}

		return coefficients;
	}

	/**
	 * Returns a string describing the logistic regression function at the node
	 * and, for inner nodes, at all nodes below it. Nodes without a constructor
	 * contribute no model description of their own.
	 * 
	 * @return the describing string
	 */
	public String modelsToString() {

		StringBuffer text = new StringBuffer();

		if (m_isLeaf) {
			if (m_hasConstr) {
				text.append("FT_" + m_leafModelNum + ":" + super.toString());
			}
			return text.toString();
		}

		if (m_hasConstr) {
			int attIndex;
			if (m_modelSelection instanceof BinC45ModelSelection) {
				attIndex = ((BinC45Split) m_localModel).attIndex();
			} else {
				attIndex = ((C45Split) m_localModel).attIndex();
			}
			text.append("FT_N" + attIndex + "#" + m_id + ":"
					+ super.toString());
		}
		for (int i = 0; i < m_sons.length; i++) {
			text.append("\n" + m_sons[i].modelsToString());
		}

		return text.toString();
	}

	/**
	 * Returns graph describing the tree (a "digraph" description). Reassigns
	 * node IDs and leaf model numbers as a side effect.
	 * 
	 * @return the graph description
	 * @throws Exception
	 *             if something goes wrong
	 */
	public String graph() throws Exception {

		StringBuffer text = new StringBuffer();

		assignIDs(-1);
		assignLeafModelNumbers(0);
		text.append("digraph FTree {\n");
		if (m_isLeaf) {
			text.append("N" + m_id + " [label=\"" + leafLabel() + "\" "
					+ "shape=box style=filled");
			text.append("]\n");
		} else {
			text.append("N" + m_id + " [label=\""
					+ m_localModel.leftSide(m_train) + "\" ");
			text.append("]\n");
			graphTree(text);
		}
		return text.toString() + "}\n";
	}

	/**
	 * Helper function for graph description of tree. Appends, for every child,
	 * the edge from this node to the child and the child's node description,
	 * and recurses into children that are not leaves.
	 * 
	 * @param text
	 *            the buffer the description is appended to
	 * @throws Exception
	 *             if something goes wrong
	 */
	protected void graphTree(StringBuffer text) throws Exception {

		for (int i = 0; i < m_sons.length; i++) {
			FTtree son = m_sons[i];

			text.append("N" + m_id + "->" + "N" + son.m_id + " [label=\""
					+ m_localModel.rightSide(i, m_train).trim() + "\"]\n");
			if (son.m_isLeaf) {
				text.append("N" + son.m_id + " [label=\"" + son.leafLabel()
						+ "\" " + "shape=box style=filled");
				text.append("]\n");
			} else {
				// note: the child's split is described using this node's
				// m_train, as in the original implementation
				text.append("N" + son.m_id + " [label=\""
						+ son.m_localModel.leftSide(m_train) + "\" ");
				text.append("]\n");
				son.graphTree(text);
			}
		}
	}

	/**
	 * Cleanup in order to save memory. Cleans up this node and, recursively,
	 * all of its children.
	 */
	public void cleanup() {
		super.cleanup();
		// m_sons is null until the node has been split
		if (!m_isLeaf && m_sons != null) {
			for (int i = 0; i < m_sons.length; i++) {
				m_sons[i].cleanup();
			}
		}
	}

	/**
	 * Returns the revision string.
	 * 
	 * @return the revision
	 */
	public String getRevision() {
		return RevisionUtils.extract("$Revision: 1.4 $");
	}
}
